{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "# 文档增强 RAG 与问题生成\n",
    "\n",
    "增强的 RAG 方法，通过文档增强和问题生成来改进。通过为每个文本块生成相关的问题，我们提高了检索过程，从而使得语言模型能够提供更好的响应。\n",
    "\n",
    "------\n",
    "实现步骤：\n",
    "- 数据采集：从 PDF 文件中提取文本。\n",
    "- 分块处理：将文本分割成可管理的块\n",
    "- **生成问题：为每个块生成相关的问题。**\n",
    "- 创建嵌入：为块和生成的问题创建嵌入。\n",
    "- 向量存储创建：使用 NumPy 构建简单的向量存储。\n",
    "- 语义搜索：检索与用户查询相关的片段和问题。\n",
    "- 响应生成：根据检索到的内容生成答案。\n",
    "- 评估：评估生成响应的质量。\n"
   ],
   "id": "c8c39e48fd835a85"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-24T06:21:10.159567Z",
     "start_time": "2025-04-24T06:21:07.753022Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import fitz\n",
    "import os\n",
    "import re\n",
    "import json\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from openai import OpenAI\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "load_dotenv()"
   ],
   "id": "c914c6c7566a1ea9",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 1
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-24T06:21:11.811290Z",
     "start_time": "2025-04-24T06:21:11.571970Z"
    }
   },
   "cell_type": "code",
   "source": [
    "client = OpenAI(\n",
    "    base_url=os.getenv(\"LLM_BASE_URL\"),\n",
    "    api_key=os.getenv(\"LLM_API_KEY\")\n",
    ")\n",
    "llm_model = os.getenv(\"LLM_MODEL_ID\")\n",
    "embedding_model = os.getenv(\"EMBEDDING_MODEL_ID\")\n",
    "\n",
    "pdf_path = \"../../data/AI_Information.en.zh-CN.pdf\""
   ],
   "id": "93abff4be7c8caff",
   "outputs": [],
   "execution_count": 2
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# 提取文本并识别章节标题",
   "id": "f100ee49730d6037"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-24T06:21:13.938852Z",
     "start_time": "2025-04-24T06:21:13.935880Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def extract_text_from_pdf(pdf_path):\n",
    "    \"\"\"\n",
    "    从 PDF 文件中提取文本，并打印前 `num_chars` 个字符。\n",
    "\n",
    "    Args:\n",
    "    pdf_path (str): Path to the PDF file.\n",
    "\n",
    "    Returns:\n",
    "    str: Extracted text from the PDF.\n",
    "    \"\"\"\n",
    "    # 打开 PDF 文件\n",
    "    mypdf = fitz.open(pdf_path)\n",
    "    all_text = \"\"  # 初始化一个空字符串以存储提取的文本\n",
    "\n",
    "    # Iterate through each page in the PDF\n",
    "    for page_num in range(mypdf.page_count):\n",
    "        page = mypdf[page_num]\n",
    "        text = page.get_text(\"text\")  # 从页面中提取文本\n",
    "        all_text += text  # 将提取的文本追加到 all_text 字符串中\n",
    "\n",
    "    return all_text  # 返回提取的文本"
   ],
   "id": "c8c58ccc4025e7fd",
   "outputs": [],
   "execution_count": 3
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# 对提取的文本进行分块",
   "id": "c71d183cb5f71d7a"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-24T06:21:16.321299Z",
     "start_time": "2025-04-24T06:21:16.318Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def chunk_text(text, n, overlap):\n",
    "    \"\"\"\n",
    "    将文本分割为重叠的块\n",
    "\n",
    "    Args:\n",
    "    text (str): 要分割的文本\n",
    "    n (int): 每个块的字符数\n",
    "    overlap (int): 块之间的重叠字符数\n",
    "\n",
    "    Returns:\n",
    "    List[str]: 文本块列表\n",
    "    \"\"\"\n",
    "    chunks = []  #\n",
    "    for i in range(0, len(text), n - overlap):\n",
    "        # 添加从当前索引到索引 + 块大小的文本块\n",
    "        chunks.append(text[i:i + n])\n",
    "\n",
    "    return chunks  # Return the list of text chunks"
   ],
   "id": "a318fd4af79dca2e",
   "outputs": [],
   "execution_count": 4
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "# 为文本块生成问题\n",
    "这是对简单 RAG 的关键增强。我们为每个文本块生成可以通过该文本块回答的问题"
   ],
   "id": "877fc22c554384b2"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-24T06:21:19.378381Z",
     "start_time": "2025-04-24T06:21:19.374022Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def generate_questions(text_chunk, num_questions=5):\n",
    "    \"\"\"\n",
    "    生成可以从给定文本块中回答的相关问题。\n",
    "\n",
    "    Args:\n",
    "    text_chunk (str): 要生成问题的文本块。\n",
    "    num_questions (int): 要生成的问题数量。\n",
    "    model (str): 用于生成问题的模型。\n",
    "\n",
    "    Returns:\n",
    "    List[str]: 生成的问题列表。\n",
    "    \"\"\"\n",
    "    # 定义系统提示\n",
    "    system_prompt = \"你是一个从文本中生成相关问题的专家。能够根据用户提供的文本生成可回答的简洁问题，重点聚焦核心信息和关键概念。\"\n",
    "\n",
    "    # 定义用户提示，包含文本块和要生成的问题数量\n",
    "    # user_prompt = f\"\"\"\n",
    "    # 根据以下文本，生成 {num_questions} 个不同的问题，这些问题只能通过此文本回答：\n",
    "    #\n",
    "    # {text_chunk}\n",
    "    #\n",
    "    # 请以编号列表的形式回复问题，且不要添加任何额外文本。\n",
    "    # \"\"\"\n",
    "    user_prompt = f\"\"\"\n",
    "    请根据以下文本内容生成{num_questions}个不同的、仅能通过该文本内容回答的问题：\n",
    "\n",
    "    {text_chunk}\n",
    "\n",
    "    请严格按以下格式回复：\n",
    "    1. 带编号的问题列表\n",
    "    2. 仅包含问题\n",
    "    3. 不要添加任何其他内容\n",
    "    \"\"\"\n",
    "\n",
    "    # 使用 OpenAI API 生成问题\n",
    "    response = client.chat.completions.create(\n",
    "        model=llm_model,\n",
    "        temperature=0.7,\n",
    "        messages=[\n",
    "            {\"role\": \"system\", \"content\": system_prompt},\n",
    "            {\"role\": \"user\", \"content\": user_prompt}\n",
    "        ]\n",
    "    )\n",
    "\n",
    "    # 从响应中提取并清理问题\n",
    "    questions_text = response.choices[0].message.content.strip()\n",
    "\n",
    "    # 使用正则表达式模式匹配提取问题\n",
    "    pattern = r'^\\d+\\.\\s*(.*)'\n",
    "    return [re.match(pattern, line).group(1) for line in questions_text.split('\\n') if line.strip()]\n",
    "    # 此处改变了原有的正则处理，避免生成的问题中没有问号以及中英文问号匹配的问题\n",
    "    # questions = []\n",
    "    # for line in questions_text.split('\\n'):\n",
    "    #     # 去除编号并清理空白\n",
    "    #     cleaned_line = re.sub(r'^\\d+\\.\\s*', '', line.strip())\n",
    "    #     # if cleaned_line and cleaned_line.endswith('？') or cleaned_line.endswith(\"?\"):\n",
    "    #     #     questions.append(cleaned_line)\n",
    "    #\n",
    "    # return questions\n"
   ],
   "id": "f29e616a6c4b90a2",
   "outputs": [],
   "execution_count": 5
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# 测试根据文本块生成的问题\n",
    "\n",
    "print(\"从 PDF 中提取文本...\")\n",
    "extracted_text = extract_text_from_pdf(pdf_path)\n",
    "\n",
    "print(\"分割文本...\")\n",
    "text_chunks = chunk_text(extracted_text, 1000, 200)\n",
    "print(f\"创建了 {len(text_chunks)} 个文本块\")\n",
    "\n",
    "print(\"处理文本块并生成问题...\")\n",
    "\n",
    "for i, chunk in enumerate(tqdm(text_chunks, desc=\"处理文本块\")):\n",
    "    questions = generate_questions(chunk, num_questions=5)\n",
    "    print(\"Generated Questions:\")\n",
    "    print(questions)\n",
    "    print(\"生成问题\".center(80, '-'))\n",
    "\n",
    "    if i == 0:\n",
    "        break"
   ],
   "id": "98c0eb1768acee69",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# 文本创建嵌入",
   "id": "3ff2344416153365"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-24T06:21:40.434740Z",
     "start_time": "2025-04-24T06:21:40.431085Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def create_embeddings(text):\n",
    "    \"\"\"\n",
    "    为给定文本创建嵌入向量，使用指定的 OpenAI 模型。\n",
    "\n",
    "    Args:\n",
    "    text (str): 要为其创建嵌入向量的输入文本。\n",
    "    model (str): 用于创建嵌入向量的模型。\n",
    "\n",
    "    Returns:\n",
    "    dict: 包含嵌入向量的 OpenAI API 响应。\n",
    "    \"\"\"\n",
    "    # 使用指定模型为输入文本创建嵌入向量\n",
    "    response = client.embeddings.create(\n",
    "        model=embedding_model,\n",
    "        input=text\n",
    "    )\n",
    "\n",
    "    return response  # 返回包含嵌入向量的响应\n"
   ],
   "id": "1ca735184eb5c179",
   "outputs": [],
   "execution_count": 6
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "# 向量存储创建\n",
    "\n",
    "使用Numpy建立一个简单的向量存储"
   ],
   "id": "2ec522a9fe3a36ff"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-24T06:21:44.480354Z",
     "start_time": "2025-04-24T06:21:44.473658Z"
    }
   },
   "cell_type": "code",
   "source": [
    "class SimpleVectorStore:\n",
    "    \"\"\"\n",
    "    使用 NumPy 实现的简单向量存储。\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        \"\"\"\n",
    "        初始化向量存储。\n",
    "        \"\"\"\n",
    "        self.vectors = []  # 存储向量\n",
    "        self.texts = []  # 存储原始文本\n",
    "        self.metadata = []  # 存储元数据\n",
    "\n",
    "    def add_item(self, text, embedding, metadata=None):\n",
    "        \"\"\"\n",
    "        向向量存储中添加一个项目。\n",
    "\n",
    "        Args:\n",
    "        text (str): 原始文本。\n",
    "        embedding (List[float]): 嵌入向量。\n",
    "        metadata (dict, optional): 额外的元数据。\n",
    "        \"\"\"\n",
    "        self.vectors.append(np.array(embedding))  # 将嵌入向量转换为 NumPy 数组并存储\n",
    "        self.texts.append(text)  # 存储原始文本\n",
    "        self.metadata.append(metadata or {})  # 存储元数据（如果未提供，则使用空字典）\n",
    "\n",
    "    def similarity_search(self, query_embedding, k=5):\n",
    "        \"\"\"\n",
    "        查找与查询嵌入向量最相似的项目。\n",
    "\n",
    "        Args:\n",
    "        query_embedding (List[float]): 查询嵌入向量。\n",
    "        k (int): 返回的结果数量。\n",
    "\n",
    "        Returns:\n",
    "        List[Dict]: 最相似的前 k 个项目及其文本和元数据。\n",
    "        \"\"\"\n",
    "        if not self.vectors:  # 如果向量列表为空，返回空列表\n",
    "            return []\n",
    "\n",
    "        # 将查询嵌入向量转换为 NumPy 数组\n",
    "        query_vector = np.array(query_embedding)\n",
    "\n",
    "        # 使用余弦相似度计算相似度\n",
    "        similarities = []\n",
    "        for i, vector in enumerate(self.vectors):  # 遍历所有向量\n",
    "            similarity = np.dot(query_vector, vector) / (np.linalg.norm(query_vector) * np.linalg.norm(vector))\n",
    "            similarities.append((i, similarity))  # 存储索引和相似度\n",
    "\n",
    "        # 按相似度降序排序\n",
    "        similarities.sort(key=lambda x: x[1], reverse=True)\n",
    "\n",
    "        # 返回前 k 个结果\n",
    "        results = []\n",
    "        for i in range(min(k, len(similarities))):  # 确保不会超出向量数量\n",
    "            idx, score = similarities[i]\n",
    "            results.append({\n",
    "                \"text\": self.texts[idx],  # 对应的文本\n",
    "                \"metadata\": self.metadata[idx],  # 对应的元数据\n",
    "                \"similarity\": score  # 相似度分数\n",
    "            })\n",
    "\n",
    "        return results\n"
   ],
   "id": "6ea9e141f63122ca",
   "outputs": [],
   "execution_count": 7
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "# 使用问题增强处理文档\n",
    "所有步骤整合在一起，用于处理文档、生成问题，并构建增强型向量存储。\n"
   ],
   "id": "17958b69f35e197f"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-24T06:21:47.748250Z",
     "start_time": "2025-04-24T06:21:47.743315Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def process_document(pdf_path, chunk_size=1000, chunk_overlap=200, questions_per_chunk=5):\n",
    "    \"\"\"\n",
    "    处理带有问题增强功能的文档。\n",
    "\n",
    "    Args:\n",
    "    pdf_path (str): PDF 文件的路径。\n",
    "    chunk_size (int): 每个文本块的字符大小。\n",
    "    chunk_overlap (int): 块之间的重叠字符数。\n",
    "    questions_per_chunk (int): 每个块生成的问题数量。\n",
    "\n",
    "    Returns:\n",
    "    Tuple[List[str], SimpleVectorStore]: 文本块列表和向量存储。\n",
    "    \"\"\"\n",
    "    print(\"从 PDF 中提取文本...\")\n",
    "    extracted_text = extract_text_from_pdf(pdf_path)\n",
    "\n",
    "    print(\"分割文本...\")\n",
    "    text_chunks = chunk_text(extracted_text, chunk_size, chunk_overlap)\n",
    "    print(f\"创建了 {len(text_chunks)} 个文本块\")\n",
    "\n",
    "    vector_store = SimpleVectorStore()\n",
    "\n",
    "    print(\"处理文本块并生成问题...\")\n",
    "    for i, chunk in enumerate(tqdm(text_chunks, desc=\"处理文本块\")):\n",
    "        # 为文本块本身创建嵌入\n",
    "        chunk_embedding_response = create_embeddings(chunk)\n",
    "        chunk_embedding = chunk_embedding_response.data[0].embedding\n",
    "\n",
    "        # 将文本块添加到向量存储中\n",
    "        vector_store.add_item(\n",
    "            text=chunk,\n",
    "            embedding=chunk_embedding,\n",
    "            metadata={\"type\": \"chunk\", \"index\": i}\n",
    "        )\n",
    "\n",
    "        # 为该文本块生成问题\n",
    "        questions = generate_questions(chunk, num_questions=questions_per_chunk)\n",
    "\n",
    "        # 为每个问题创建嵌入并添加到向量存储中\n",
    "        for j, question in enumerate(questions):\n",
    "            question_embedding_response = create_embeddings(question)\n",
    "            question_embedding = question_embedding_response.data[0].embedding\n",
    "\n",
    "            # 将问题添加到向量存储中\n",
    "            vector_store.add_item(\n",
    "                text=question,\n",
    "                embedding=question_embedding,\n",
    "                metadata={\"type\": \"question\", \"chunk_index\": i, \"original_chunk\": chunk}\n",
    "            )\n",
    "\n",
    "    return text_chunks, vector_store\n"
   ],
   "id": "5d122f94ed10aece",
   "outputs": [],
   "execution_count": 8
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# 提取和处理文档",
   "id": "28b4cc7b7c055b3a"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-24T06:22:33.665739Z",
     "start_time": "2025-04-24T06:21:52.428468Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# 处理文档（提取文本、创建块、生成问题、构建向量存储）\n",
    "text_chunks, vector_store = process_document(\n",
    "    pdf_path,\n",
    "    chunk_size=1000,  # 每个文本块的字符大小为1000\n",
    "    chunk_overlap=200,  # 块之间的重叠字符数为200\n",
    "    questions_per_chunk=3  # 每个块生成3个问题\n",
    ")\n",
    "\n",
    "# 打印向量存储中的项目数量\n",
    "print(f\"向量存储包含 {len(vector_store.texts)} 个项目\")  # 13*4=52，一个块由1个chunk,3个question组成\n"
   ],
   "id": "55c6672fd3078d12",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "从 PDF 中提取文本...\n",
      "分割文本...\n",
      "创建了 13 个文本块\n",
      "处理文本块并生成问题...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "处理文本块: 100%|██████████| 13/13 [00:41<00:00,  3.17s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "向量存储包含 52 个项目\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "execution_count": 9
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "# 语义检索\n",
    "\n",
    "实现了一个语义搜索功能，该功能类似于简单的RAG（检索增强生成）实现，但针对我们的增强型向量存储进行了适应性调整。"
   ],
   "id": "7f87e67195e35151"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-24T06:22:57.420078Z",
     "start_time": "2025-04-24T06:22:57.416777Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def semantic_search(query, vector_store, k=5):\n",
    "    \"\"\"\n",
    "    使用查询和向量存储执行语义搜索。\n",
    "\n",
    "    Args：\n",
    "    query (str): 搜索查询。\n",
    "    vector_store (SimpleVectorStore): 要搜索的向量存储。\n",
    "    k (int): 返回的结果数量。\n",
    "\n",
    "    Returns：\n",
    "    List[Dict]: 最相关的前 k 个结果列表，每个结果包含文本和元数据信息。\n",
    "    \"\"\"\n",
    "    # 为查询创建嵌入\n",
    "    query_embedding_response = create_embeddings(query)\n",
    "    query_embedding = query_embedding_response.data[0].embedding\n",
    "\n",
    "    # 搜索向量存储\n",
    "    results = vector_store.similarity_search(query_embedding, k=k)\n",
    "\n",
    "    return results\n"
   ],
   "id": "ca4fe99f6e0fbd0e",
   "outputs": [],
   "execution_count": 10
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# 在向量存储中搜索问题",
   "id": "38c386868025b43d"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-24T06:23:02.467800Z",
     "start_time": "2025-04-24T06:23:02.049781Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# 加载来自 JSON 文件的验证数据\n",
    "with open('../../data/val.json', mode=\"r\", encoding=\"utf-8\") as f:\n",
    "    data = json.load(f)\n",
    "\n",
    "# 从验证数据中提取第一个查询\n",
    "query = data[0]['question']\n",
    "\n",
    "# 执行语义搜索以找到相关的内容\n",
    "search_results = semantic_search(query, vector_store, k=5)\n",
    "\n",
    "print(\"Query:\", query)  # 打印查询内容\n",
    "print(\"\\nSearch Results:\")  # 打印搜索结果标题\n",
    "\n",
    "# 按类型组织结果\n",
    "chunk_results = []  # 存储文档块的结果\n",
    "question_results = []  # 存储问题的结果\n",
    "\n",
    "# 将搜索结果分为文档块和问题两类\n",
    "for result in search_results:\n",
    "    if result[\"metadata\"][\"type\"] == \"chunk\":  # 如果结果是文档块类型\n",
    "        chunk_results.append(result)\n",
    "    else:  # 如果结果是问题类型\n",
    "        question_results.append(result)\n",
    "\n",
    "# 打印文档块结果\n",
    "print(\"\\nRelevant Document Chunks:\")  # 打印相关文档块标题\n",
    "for i, result in enumerate(chunk_results):\n",
    "    print(f\"Context {i + 1} (similarity: {result['similarity']:.4f}):\")  # 打印每个文档块的相似度分数\n",
    "    print(result[\"text\"][:300] + \"...\")  # 打印文档块的前300个字符\n",
    "    print(\"=====================================\")  # 分隔符\n",
    "\n",
    "# 打印匹配的问题\n",
    "print(\"\\nMatched Questions:\")  # 打印匹配问题标题\n",
    "for i, result in enumerate(question_results):\n",
    "    print(f\"Question {i + 1} (similarity: {result['similarity']:.4f}):\")  # 打印每个问题的相似度分数\n",
    "    print(result[\"text\"])  # 打印问题内容\n",
    "    chunk_idx = result[\"metadata\"][\"chunk_index\"]  # 获取问题所属的文档块索引\n",
    "    print(f\"From chunk {chunk_idx}\")  # 打印问题来源的文档块索引\n",
    "    print(\"=====================================\")  # 分隔符\n"
   ],
   "id": "1e0663365f547905",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Query: 什么是‘可解释人工智能’，为什么它被认为很重要？\n",
      "\n",
      "Search Results:\n",
      "\n",
      "Relevant Document Chunks:\n",
      "\n",
      "Matched Questions:\n",
      "Question 1 (similarity: 0.9426):\n",
      "可解释人工智能(XAI)技术的主要作用是什么？\n",
      "From chunk 10\n",
      "=====================================\n",
      "Question 2 (similarity: 0.9400):\n",
      "可解释人工智能（XAI）的主要目标是什么？\n",
      "From chunk 3\n",
      "=====================================\n",
      "Question 3 (similarity: 0.9400):\n",
      "可解释人工智能（XAI）的研究重点是什么？\n",
      "From chunk 8\n",
      "=====================================\n",
      "Question 4 (similarity: 0.9350):\n",
      "为什么透明度和可解释性对建立人工智能系统的信任至关重要？\n",
      "From chunk 11\n",
      "=====================================\n",
      "Question 5 (similarity: 0.9158):\n",
      "人工智能（AI）的定义是什么？  \n",
      "From chunk 0\n",
      "=====================================\n"
     ]
    }
   ],
   "execution_count": 11
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "# 生成响应的上下文\n",
    "现在我们通过结合相关文档块和问题的信息来准备上下文。"
   ],
   "id": "4262eba16e614ef8"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-24T06:23:49.486056Z",
     "start_time": "2025-04-24T06:23:49.481522Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def prepare_context(search_results):\n",
    "    \"\"\"\n",
    "    从语义搜索结果中准备一个统一的上下文，用于生成回答。\n",
    "\n",
    "    Args:\n",
    "    search_results (List[Dict]): 语义搜索的结果。\n",
    "\n",
    "    Returns:\n",
    "    str: 合并后的上下文字符串。\n",
    "    \"\"\"\n",
    "    # 提取结果中引用的独特文档块\n",
    "    chunk_indices = set()\n",
    "    context_chunks = []\n",
    "\n",
    "    # 首先添加直接匹配的文档块\n",
    "    for result in search_results:\n",
    "        if result[\"metadata\"][\"type\"] == \"chunk\":\n",
    "            chunk_indices.add(result[\"metadata\"][\"index\"])\n",
    "            context_chunks.append(f\"Chunk {result['metadata']['index']}:\\n{result['text']}\")\n",
    "\n",
    "    # 然后添加由问题引用的文档块\n",
    "    for result in search_results:\n",
    "        if result[\"metadata\"][\"type\"] == \"question\":\n",
    "            chunk_idx = result[\"metadata\"][\"chunk_index\"]\n",
    "            if chunk_idx not in chunk_indices:\n",
    "                chunk_indices.add(chunk_idx)\n",
    "                context_chunks.append(\n",
    "                    f\"Chunk {chunk_idx} (referenced by question '{result['text']}'):\\n{result['metadata']['original_chunk']}\")\n",
    "\n",
    "    # 合并所有上下文块\n",
    "    full_context = \"\\n\\n\".join(context_chunks)\n",
    "    return full_context\n"
   ],
   "id": "bcf5d92dbb12d858",
   "outputs": [],
   "execution_count": 12
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# 根据检索的文本块生成回答",
   "id": "38dae9df2f0c0c1e"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-24T06:23:53.429559Z",
     "start_time": "2025-04-24T06:23:53.425793Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def generate_response(query, context):\n",
    "    \"\"\"\n",
    "    根据查询和上下文生成回答。\n",
    "\n",
    "    Args:\n",
    "    query (str): 用户的问题。\n",
    "    context (str): 从向量存储中检索到的上下文信息。\n",
    "    model (str): 用于生成回答的模型。\n",
    "\n",
    "    Returns:\n",
    "    str: 生成的回答。\n",
    "    \"\"\"\n",
    "    # 定义系统提示，指导AI助手严格基于给定的上下文进行回答。\n",
    "    # 如果无法直接从提供的上下文中得出答案，则回复：“我没有足够的信息来回答这个问题。”\n",
    "    system_prompt = \"你是一个AI助手，严格根据给定的上下文进行回答。如果无法直接从提供的上下文中得出答案，请回复：'我没有足够的信息来回答这个问题。'\"\n",
    "\n",
    "    # 构建用户提示，包含上下文和问题。\n",
    "    # 要求AI助手仅基于提供的上下文回答问题，并保持简洁和准确。\n",
    "    user_prompt = f\"\"\"\n",
    "        上下文内容:\n",
    "        {context}\n",
    "\n",
    "        问题: {query}\n",
    "\n",
    "        请仅根据上述上下文回答问题, 并保持简明扼要。\n",
    "    \"\"\"\n",
    "\n",
    "    # 使用指定的模型生成回答。\n",
    "    # 设置temperature为0以确保输出更确定和一致。\n",
    "    response = client.chat.completions.create(\n",
    "        model=llm_model,\n",
    "        temperature=0,\n",
    "        messages=[\n",
    "            {\"role\": \"system\", \"content\": system_prompt},\n",
    "            {\"role\": \"user\", \"content\": user_prompt}\n",
    "        ]\n",
    "    )\n",
    "\n",
    "    # 返回生成的回答内容。\n",
    "    return response.choices[0].message.content\n"
   ],
   "id": "7d5e6489e44ed0e",
   "outputs": [],
   "execution_count": 13
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-24T06:24:00.621822Z",
     "start_time": "2025-04-24T06:23:57.263492Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Prepare context from search results\n",
    "context = prepare_context(search_results)\n",
    "\n",
    "# Generate response\n",
    "response_text = generate_response(query, context)\n",
    "\n",
    "print(\"\\nQuery:\", query)\n",
    "print(\"\\nResponse:\")\n",
    "print(response_text)"
   ],
   "id": "1bcb61c41932020f",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Query: 什么是‘可解释人工智能’，为什么它被认为很重要？\n",
      "\n",
      "Response:\n",
      "可解释人工智能（XAI）旨在使人工智能系统的决策过程更加透明易懂，帮助用户理解其决策方式和依据。它被认为很重要是因为能增强对AI系统的信任、提高问责制，并确保其公平性和准确性。\n"
     ]
    }
   ],
   "execution_count": 14
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# 评估\n",
   "id": "ea782318ac65e7bb"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-24T06:24:11.204688Z",
     "start_time": "2025-04-24T06:24:11.199798Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def evaluate_response(query, response, reference_answer):\n",
    "    \"\"\"\n",
    "    对AI生成的回答进行评估，将其与参考答案进行对比。\n",
    "\n",
    "    Args:\n",
    "    query (str): 用户的问题。\n",
    "    response (str): AI生成的回答。\n",
    "    reference_answer (str): 参考/理想答案。\n",
    "    model (str): 用于评估的模型。\n",
    "\n",
    "    Returns:\n",
    "    str: 评估反馈。\n",
    "    \"\"\"\n",
    "    # 定义评估系统的系统提示\n",
    "    evaluate_system_prompt = \"\"\"您是一个智能评估系统，负责评估AI回答的质量。\n",
    "    请将AI助手的回答与真实/参考答案进行对比，基于以下几点进行评估：\n",
    "        1. 事实正确性 - 回答是否包含准确信息？\n",
    "        2. 完整性 - 是否涵盖参考内容的所有重要方面？\n",
    "        3. 相关性 - 是否直接针对问题作出回应？\n",
    "\n",
    "        请分配0到1之间的评分：\n",
    "        - 1.0：内容与含义完全匹配\n",
    "        - 0.8：非常好，仅有少量遗漏/差异\n",
    "        - 0.6：良好，涵盖主要要点但遗漏部分细节\n",
    "        - 0.4：部分正确答案但存在显著遗漏\n",
    "        - 0.2：仅包含少量相关信息\n",
    "        - 0.0：错误或无关信息\n",
    "\n",
    "    请提供评分并附理由说明。\n",
    "    \"\"\"\n",
    "\n",
    "    # 创建评估提示\n",
    "    # 包含用户问题、AI回答、参考答案以及要求评估的内容。\n",
    "    evaluation_prompt = f\"\"\"\n",
    "        用户问题: {query}\n",
    "\n",
    "        AI回答:\n",
    "        {response}\n",
    "\n",
    "        参考答案:\n",
    "        {reference_answer}\n",
    "\n",
    "        请根据参考答案评估AI的回答。\n",
    "    \"\"\"\n",
    "\n",
    "    # 生成评估结果\n",
    "    # 使用指定的模型生成评估结果。\n",
    "    eval_response = client.chat.completions.create(\n",
    "        model=llm_model,\n",
    "        temperature=0,\n",
    "        messages=[\n",
    "            {\"role\": \"system\", \"content\": evaluate_system_prompt},\n",
    "            {\"role\": \"user\", \"content\": evaluation_prompt}\n",
    "        ]\n",
    "    )\n",
    "\n",
    "    # 返回评估内容\n",
    "    return eval_response.choices[0].message.content\n"
   ],
   "id": "5c0312f83bebb8f1",
   "outputs": [],
   "execution_count": 15
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-24T06:24:19.482538Z",
     "start_time": "2025-04-24T06:24:14.408513Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Get reference answer from validation data\n",
    "reference_answer = data[0]['ideal_answer']\n",
    "\n",
    "# Evaluate the response\n",
    "evaluation = evaluate_response(query, response_text, reference_answer)\n",
    "\n",
    "print(\"\\nEvaluation:\")\n",
    "print(evaluation)"
   ],
   "id": "771c0373da42f681",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Evaluation:\n",
      "评分：1.0\n",
      "\n",
      "理由说明：\n",
      "AI的回答与参考答案在内容与含义上完全匹配。两者都准确地定义了可解释人工智能（XAI）的核心目标，即提高透明度和可理解性，并强调了其重要性在于建立信任、问责制和确保公平性。AI的回答甚至比参考答案更详细一些，提到了“准确性”，这进一步丰富了回答的内容。因此，该回答完全符合评估标准，应给予满分。\n"
     ]
    }
   ],
   "execution_count": 16
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
